import os
from typing import List, Tuple, Optional, Callable
import numpy as np
import torch

from Utils import logger, config as cfg
from Utils.Constants import Diff as DiffConst
from Utils.utils import pkl_file_reader_gen
from DataHandling.FileBasedDatasetBase import FileBasedDataset
from Encoder.FeatureMapAutoEncoder import FeatureMapAutoEncoder


class RegressionSingleFeatureMapDataset(FileBasedDataset):
    """
    Regression dataset whose samples are entire saved feature maps (optionally passed through a frozen encoder)
    and whose targets come from the results files handled by the FileBasedDataset base class.
    """

    def __init__(self, maps_files: List[str], results_folders: List[str], result_metric_name: str, pre_load_maps: bool,
                 change_result_values: Optional[Callable[[np.ndarray], np.ndarray]],
                 auto_encoder: Optional[FeatureMapAutoEncoder]):
        """
        A dataset for a model that uses saved feature maps as input and the output from the results files.
        The data is the entire feature map.
        This dataset also replaces nan and infinite values in the maps.
        :param maps_files: List of pre saved map files as pkl files
        :param results_folders: List of folders that contain all the relevant results files
        :param result_metric_name: name of metric to use from the result file
        :param pre_load_maps: True for loading all the maps during __init__ and using them from memory instead of
                    loading them on the fly during __getitem__
        :param change_result_values: Optional function applied once to the target values after the base class
                    was initialised. If None (or if there are no targets) the targets are left untouched.
        :param auto_encoder: Optional auto encoder for encoding the feature maps without it being part of the training.
                    If None the dataset returns the original feature map.
        """
        # These attributes must exist before _filter_bad_maps runs, since pre-loading encodes the maps.
        self._pre_load_maps = pre_load_maps
        self._encoder, self._encoder_is_double = self._set_auto_encoder(auto_encoder)
        maps_files, maps = self._filter_bad_maps(maps_files, pre_load_maps)
        super().__init__(maps_files, results_folders, result_metric_name)
        # `maps` is already an empty list when pre_load_maps is False.
        self._saved_maps = maps if pre_load_maps else list()
        if change_result_values is not None and self._y_true is not None:
            self._y_true = change_result_values(self._y_true)

    @staticmethod
    def _set_auto_encoder(auto_encoder: Optional[FeatureMapAutoEncoder]):
        """
        Extract the encoder part of the auto encoder and put it in evaluation mode
        :param auto_encoder: the auto encoder to take the encoder from, or None for no encoding
        :return: Tuple of:
                    * the encoder in eval mode, or None if no auto encoder was given
                    * True if the auto encoder reports it works with double precision, False otherwise
        """
        if auto_encoder is None:
            return None, False

        encoder = auto_encoder.encoder()
        is_double = auto_encoder.is_double()
        encoder.eval()
        return encoder, is_double

    def _filter_bad_maps(self, maps_files: List[str], pre_load_maps: bool) -> Tuple[List[str], List[np.ndarray]]:
        """
        Filter any map whose number of steps is not DiffConst.NUMBER_STEPS_SAVED. Files that contain no data at all
        are treated as bad maps as well.
        :param maps_files: list of maps files
        :param pre_load_maps: True if load all maps at __init__ instead of on the fly during training
        :return: Tuple of:
                    * list with the paths of all the valid maps
                    * list of all the loaded (and pre-processed) valid maps if pre_load_maps is True else an
                      empty list. Entry i of this list belongs to entry i of the paths list.
        """
        valid_files = list()
        bad_files = list()
        loaded_maps = list()
        for file in maps_files:
            parts = list(pkl_file_reader_gen(file))
            if not parts:
                # Nothing could be read from the file; report it as bad instead of failing inside np.concatenate.
                bad_files.append(file)
                continue

            data = self._pkl_loader_to_ndarray(parts)
            if data.shape[0] != DiffConst.NUMBER_STEPS_SAVED:
                bad_files.append(file)
                continue

            valid_files.append(file)
            if pre_load_maps:
                # Only valid maps are kept in memory, so the saved maps stay index-aligned with the valid files
                # and a malformed map is never pushed through the encoder.
                loaded_maps.append(self._prepare_map(data))

        if bad_files:
            logger().error('FeatureMapBasedDataset::__filter_bad_maps', ValueError,
                           f'Bad feature maps (not enough steps):  {bad_files}\n')
        return valid_files, loaded_maps

    @staticmethod
    def _pkl_loader_to_ndarray(data):
        """
        Convert a full file of feature map into a ndarray. This is needed because a feature map can be saved as a single
        array or multiple arrays that need concatenation
        :param data: list that contains the full exhaustion of the pkl_file_reader_gen of a feature map
        :return: ndarray with the full data as original map
        """
        if len(data) == 1:  # one part of feature map in file
            return data[0]
        # Map was saved as multiple ndarrays, join them along the first axis
        return np.concatenate(data)

    @staticmethod
    def _nan_policy(data: np.ndarray) -> np.ndarray:
        """
        Replace every nan value with 0
        :param data: the feature map, it is not modified
        :return: a copy of data without nan values
        """
        fixed_data = data.copy()
        fixed_data[np.isnan(fixed_data)] = 0
        return fixed_data

    @staticmethod
    def _inf_policy(data: np.ndarray, val: int = 1_000) -> np.ndarray:
        """
        Replace every infinite value (positive and negative alike) with val
        :param data: the feature map, it is not modified
        :param val: the value that replaces the infinite values
        :return: a copy of data without infinite values
        """
        fixed_data = data.copy()
        fixed_data[np.isinf(fixed_data)] = val
        return fixed_data

    def _prepare_map(self, data: np.ndarray):
        """
        The single pre-processing pipeline of a raw feature map, shared by the pre-load path and __getitem__ so
        both always produce the same sample for the same file.
        :param data: the raw feature map as loaded from the file
        :return: the map with its last axis moved to position 1, nan and infinite values replaced, and encoded
                 if an encoder was given
        """
        # The nan/inf fixes are element-wise, so applying them after the axis move yields the same values as before.
        data = np.moveaxis(data, -1, 1)
        data = self._nan_policy(data)
        data = self._inf_policy(data)
        return self._encode(data)

    def _encode(self, data):
        """
        Encode the data with the encoder, without tracking gradients
        :param data: the pre-processed feature map
        :return: data itself if there is no encoder, otherwise the encoder output as a ndarray on the cpu after
                 removing all the axes of length 1
        """
        if self._encoder is None:
            return data

        if self._encoder_is_double:
            tensor = torch.tensor(data, dtype=torch.double, device=cfg.device)
        else:
            tensor = torch.tensor(data, device=cfg.device)
        with torch.no_grad():
            encoded = self._encoder(tensor)
        return encoded.cpu().numpy().squeeze()

    def __getitem__(self, index):
        """
        :param index: index of the sample
        :return: Tuple of the (pre-processed, optionally encoded) feature map and its target value
        """
        if self._pre_load_maps:
            data = self._saved_maps[index]
        else:
            raw_map = self._pkl_loader_to_ndarray(list(pkl_file_reader_gen(self._files_paths[index])))
            data = self._prepare_map(raw_map)

        return data, self._y_true[index]

    @property
    def models_ids(self):
        """
        :return: a copy of the models ids, so the caller can't change the ids that the dataset holds
        """
        return self._models_ids.copy()

    @staticmethod
    def create_dataset(maps_folders: Optional[List[str]], maps_files: Optional[List[str]], results_folders: List[str],
                       result_metric_name: str, pre_load_maps: bool, auto_encoder: Optional[FeatureMapAutoEncoder],
                       target_mult_10: bool):
        """
        :param maps_folders: List of folder with pre saved feature maps as pkl files. Used only if maps_files is
                    None or empty, in that case every entry in these folders is used as a map file
        :param maps_files: List of pre saved map files as pkl files
        :param results_folders: List of folders that contain all the relevant results files
        :param result_metric_name: name of metric to use from the result file
        :param pre_load_maps: True for loading all the maps during __init__ and using them from memory instead of
                    loading them on the fly during __getitem__
        :param auto_encoder: Optional auto encoder for encoding the feature maps without it being part of the training.
                    If None the dataset returns the original feature map.
        :param target_mult_10: True for multiplying the target values by 10, for accuracy it will create
                    targets value range [0,10]
        :return: RegressionSingleFeatureMapDataset
        """
        if not maps_folders and not maps_files:
            logger().error('FeatureMapBasedDataset::create_dataset', ValueError,
                           "maps_folders and maps_files can't both be None, one mast have value")
        if maps_files:
            all_map_files = maps_files
        else:
            all_map_files = [os.path.join(curr_folder, curr_file)
                             for curr_folder in maps_folders for curr_file in os.listdir(curr_folder)]
        change_result_func = (lambda vals: vals * 10) if target_mult_10 else None
        return RegressionSingleFeatureMapDataset(all_map_files, results_folders, result_metric_name, pre_load_maps,
                                                 change_result_values=change_result_func, auto_encoder=auto_encoder)
